{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6a08763a-aed6-4f91-94d0-80a3c0e2665b",
   "metadata": {},
   "source": [
    "### Weeks 2 - Day 2 - Gradio Chatbot with LiteLLM (Model Routing)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4f38c58-5ceb-4d5e-b538-c1acdc881f73",
   "metadata": {},
   "source": [
    "**Author** : [Marcus Rosen](https://github.com/MarcusRosen)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36f4814a-2bfc-4631-97d7-7a474fa1cc8e",
   "metadata": {},
   "source": [
    "[LiteLLM](https://docs.litellm.ai/docs/) provides the abilitty to call different LLM providers via a unified interface, returning results in OpenAI compatible formats.\n",
    "\n",
    "Features:\n",
    "- Model Selection in Gradio (Anthropic, OpenAI, Gemini)\n",
    "- Single Inference function for all model providers via LiteLLM (call_llm)\n",
    "- Streaming **NOTE:** Bug when trying to stream in Gradio, but works directly in Notebook\n",
    "- Debug Tracing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "b6c12598-4773-4f85-93ca-0128d74fbca0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from litellm import completion\n",
    "import gradio as gr\n",
    "from dotenv import load_dotenv\n",
    "from bs4 import BeautifulSoup\n",
    "import os\n",
    "import requests\n",
    "import json"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d24be370-5347-47fb-a58e-21a1b5409ab2",
   "metadata": {},
   "source": [
    "#### Load API Keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e03afbe9-16aa-434c-a701-b3bfe75e927d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OpenAI API Key exists and begins sk-proj-\n",
      "Anthropic API Key exists and begins sk-ant-\n",
      "Google API Key exists and begins AIzaSyDC\n"
     ]
    }
   ],
   "source": [
    "# Load environment variables in a file called .env\n",
    "# Print the key prefixes to help with any debugging\n",
    "\n",
    "load_dotenv(override=True)\n",
    "openai_api_key = os.getenv('OPENAI_API_KEY')\n",
    "anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
    "google_api_key = os.getenv('GEMINI_API_KEY')\n",
    "\n",
    "if openai_api_key:\n",
    "    print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
    "else:\n",
    "    print(\"OpenAI API Key not set\")\n",
    "    \n",
    "if anthropic_api_key:\n",
    "    print(f\"Anthropic API Key exists and begins {anthropic_api_key[:7]}\")\n",
    "else:\n",
    "    print(\"Anthropic API Key not set\")\n",
    "\n",
    "if google_api_key:\n",
    "    print(f\"Google API Key exists and begins {google_api_key[:8]}\")\n",
    "   # import google.generativeai\n",
    "   # google.generativeai.configure()\n",
    "else:\n",
    "    print(\"Gemini API Key not set\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66e46447-0e73-49ef-944a-d1e8fae4986e",
   "metadata": {},
   "source": [
    "### Use LiteLLM to abstract out the model provider"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "473c2029-ca74-4f1e-92ac-05f7817ff7df",
   "metadata": {},
   "outputs": [],
   "source": [
    "def call_llm(model, system_prompt, user_prompt, json_format_response=False, streaming=False):\n",
    "    if DEBUG_OUTPUT:    \n",
    "        print(\"call_llm()\")\n",
    "        print(f\"streaming={streaming}\")\n",
    "        print(f\"json_format_response={json_format_response}\")\n",
    "    \n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": system_prompt},\n",
    "        {\"role\": \"user\", \"content\": user_prompt}\n",
    "    ]\n",
    "\n",
    "    payload = {\n",
    "        \"model\": model,\n",
    "        \"messages\": messages\n",
    "    }\n",
    "    # Use Json Reponse Format\n",
    "    # Link: https://docs.litellm.ai/docs/completion/json_mode\n",
    "    if json_format_response:\n",
    "        payload[\"response_format\"]: { \"type\": \"json_object\" }\n",
    "    \n",
    "    if streaming:\n",
    "        payload[\"stream\"] = True\n",
    "        response = completion(**payload)\n",
    "        # Return a generator expression instead of using yield in the function\n",
    "        return (part.choices[0].delta.content or \"\" for part in response)\n",
    "    else:\n",
    "        response = completion(**payload)\n",
    "        return response[\"choices\"][0][\"message\"][\"content\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f45e0972-a6a0-4237-8a69-e6f165f30e0d",
   "metadata": {},
   "source": [
    "### Brochure building functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "c76d4ff9-0f18-49d0-a9b5-2c6c0bad359a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A class to represent a Webpage\n",
    "\n",
    "# Some websites need you to use proper headers when fetching them:\n",
    "headers = {\n",
    " \"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36\"\n",
    "}\n",
    "\n",
    "class Website:\n",
    "    \"\"\"\n",
    "    A utility class to represent a Website that we have scraped, now with links\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, url):\n",
    "        self.url = url\n",
    "        response = requests.get(url, headers=headers)\n",
    "        self.body = response.content\n",
    "        soup = BeautifulSoup(self.body, 'html.parser')\n",
    "        self.title = soup.title.string if soup.title else \"No title found\"\n",
    "        if soup.body:\n",
    "            for irrelevant in soup.body([\"script\", \"style\", \"img\", \"input\"]):\n",
    "                irrelevant.decompose()\n",
    "            self.text = soup.body.get_text(separator=\"\\n\", strip=True)\n",
    "        else:\n",
    "            self.text = \"\"\n",
    "        links = [link.get('href') for link in soup.find_all('a')]\n",
    "        self.links = [link for link in links if link]\n",
    "\n",
    "    def get_contents(self):\n",
    "        return f\"Webpage Title:\\n{self.title}\\nWebpage Contents:\\n{self.text}\\n\\n\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "ff41b687-3a46-4bca-a031-1148b91a4fdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_links(url, model):\n",
    "    if DEBUG_OUTPUT:\n",
    "        print(\"get_links()\")\n",
    "    website = Website(url)\n",
    "\n",
    "    link_system_prompt = \"You are provided with a list of links found on a webpage. \\\n",
    "    You are able to decide which of the links would be most relevant to include in a brochure about the company, \\\n",
    "    such as links to an About page, or a Company page, or Careers/Jobs pages.\\n\"\n",
    "    link_system_prompt += \"You should respond in raw JSON exactly as specified in this example. DO NOT USE MARKDOWN.\"\n",
    "    link_system_prompt += \"\"\"\n",
    "    {\n",
    "        \"links\": [\n",
    "            {\"type\": \"about page\", \"url\": \"https://full.url/goes/here/about\"},\n",
    "            {\"type\": \"careers page\": \"url\": \"https://another.full.url/careers\"}\n",
    "        ]\n",
    "    }\n",
    "    \"\"\"\n",
    "    \n",
    "    result = call_llm(model=model, \n",
    "                      system_prompt=link_system_prompt, \n",
    "                      user_prompt=get_links_user_prompt(website), \n",
    "                      json_format_response=True, \n",
    "                      streaming=False)\n",
    "    if DEBUG_OUTPUT:\n",
    "        print(result)\n",
    "    return json.loads(result)\n",
    "\n",
    "def get_links_user_prompt(website):\n",
    "    if DEBUG_OUTPUT:\n",
    "        print(\"get_links_user_prompt()\")\n",
    "        \n",
    "    user_prompt = f\"Here is the list of links on the website of {website.url} - \"\n",
    "    user_prompt += \"please decide which of these are relevant web links for a brochure about the company, respond with the full https URL in JSON format. \\\n",
    "Do not include Terms of Service, Privacy, email links.\\n\"\n",
    "    user_prompt += \"Links (some might be relative links):\\n\"\n",
    "    user_prompt += \"\\n\".join(website.links)\n",
    "\n",
    "    if DEBUG_OUTPUT:\n",
    "        print(user_prompt)\n",
    "    \n",
    "    return user_prompt\n",
    "\n",
    "def get_all_details(url, model):\n",
    "    if DEBUG_OUTPUT:\n",
    "        print(\"get_all_details()\")\n",
    "        \n",
    "    result = \"Landing page:\\n\"\n",
    "    result += Website(url).get_contents()\n",
    "    links = get_links(url, model)\n",
    "    if DEBUG_OUTPUT:\n",
    "        print(\"Found links:\", links)\n",
    "    for link in links[\"links\"]:\n",
    "        result += f\"\\n\\n{link['type']}\\n\"\n",
    "        result += Website(link[\"url\"]).get_contents()\n",
    "    return result\n",
    "\n",
    "def get_brochure_user_prompt(company_name, url, model):\n",
    "    \n",
    "    if DEBUG_OUTPUT:\n",
    "        print(\"get_brochure_user_prompt()\")\n",
    "    \n",
    "    user_prompt = f\"You are looking at a company called: {company_name}\\n\"\n",
    "    user_prompt += f\"Here are the contents of its landing page and other relevant pages; use this information to build a short brochure of the company in markdown.\\n\"\n",
    "    user_prompt += get_all_details(url, model)\n",
    "    user_prompt = user_prompt[:5000] # Truncate if more than 5,000 characters\n",
    "    return user_prompt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "cf7512a1-a498-44e8-a234-9affb72efe60",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_brochure(company_name, url, model, streaming):\n",
    "\n",
    "    system_prompt = \"You are an assistant that analyzes the contents of several relevant pages from a company website \\\n",
    "and creates a short brochure about the company for prospective customers, investors and recruits. Respond in markdown.\\\n",
    "Include details of company culture, customers and careers/jobs if you have the information.\"\n",
    "    if streaming:\n",
    "        result = call_llm(model=model, system_prompt=system_prompt, user_prompt=get_brochure_user_prompt(company_name, url, model), streaming=True)\n",
    "        return (p for p in result)\n",
    "    else:   \n",
    "        return call_llm(model=model, system_prompt=system_prompt, user_prompt=get_brochure_user_prompt(company_name, url, model), streaming=False)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecb6d212-ddb6-4170-81bf-8f3ea54479f8",
   "metadata": {},
   "source": [
    "#### Testing Model before implenting Gradio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "de89843a-08ac-4431-8c83-21a93c05f764",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# Rio Tinto: Providing the Materials for a Sustainable Future\n",
      "\n",
      "## About Rio Tinto\n",
      "\n",
      "Rio Tinto is a global mining and metals company, operating in 35 countries with over 60,000 employees. Their purpose is to find better ways to provide the materials the world needs. Continuous improvement and innovation are at the core of their DNA, as they work to responsibly supply the metals and minerals critical for urbanization and the transition to a low-carbon economy.\n",
      "\n",
      "## Our Products\n",
      "\n",
      "Rio Tinto's diverse portfolio includes:\n",
      "\n",
      "- Iron Ore: The primary raw material used to make steel, which is strong, long-lasting and cost-efficient.\n",
      "- Aluminium: A lightweight, durable and recyclable metal.\n",
      "- Copper: A tough, malleable, corrosion-resistant and recyclable metal that is an excellent conductor of heat and electricity.\n",
      "- Lithium: The lightest of all metals, a key element for low-carbon technologies.\n",
      "- Diamonds: Ethically-sourced, high-quality diamonds.\n",
      "\n",
      "## Sustainability and Innovation\n",
      "\n",
      "Sustainability is at the heart of Rio Tinto's operations. They are targeting net zero emissions by 2050 and investing in nature-based solutions to complement their decarbonization efforts. Innovation is a key focus, with research and development into new technologies to improve efficiency and reduce environmental impact.\n",
      "\n",
      "## Careers and Culture\n",
      "\n",
      "Rio Tinto values its 60,000 employees and is committed to fostering a diverse and inclusive workplace. They offer a wide range of career opportunities, from mining and processing to engineering, finance, and more. Rio Tinto's culture is centered on safety, collaboration, and continuous improvement, with a strong emphasis on sustainability and responsible business practices.\n",
      "\n",
      "## Conclusion\n",
      "\n",
      "Rio Tinto is a global leader in the mining and metals industry, providing the materials essential for a sustainable future. Through their commitment to innovation, sustainability, and their talented workforce, Rio Tinto is well-positioned to meet the world's growing demand for critical resources.\n",
      "\u001b[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\u001b[0m\n",
      "LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.\n",
      "\n",
      "<generator object call_llm.<locals>.<genexpr> at 0x7f80ca5da0c0>\n"
     ]
    }
   ],
   "source": [
    "MODEL=\"claude-3-haiku-20240307\"\n",
    "DEBUG_OUTPUT=False\n",
    "streaming=True\n",
    "result = create_brochure(company_name=\"Rio Tinto\", url=\"http://www.riotinto.com\", model=MODEL, streaming=streaming)\n",
    "\n",
    "if streaming:\n",
    "    for chunk in result:\n",
    "        print(chunk, end=\"\", flush=True)\n",
    "else:\n",
    "    print(result)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f330c92-6280-4dae-b4d8-717a56edb236",
   "metadata": {},
   "source": [
    "#### Gradio Setup\n",
    "Associate Dropdown values with the model we want to use.\n",
    "Link: https://www.gradio.app/docs/gradio/dropdown#initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f38862-3728-4bba-9e16-6f9fab276145",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "DEBUG_OUTPUT=True\n",
    "view = gr.Interface(\n",
    "    fn=create_brochure,\n",
    "    inputs=[\n",
    "        gr.Textbox(label=\"Company name:\"),\n",
    "        gr.Textbox(label=\"Landing page URL including http:// or https://\"),\n",
    "        gr.Dropdown(choices=[(\"GPT 4o Mini\", \"gpt-4o-mini\"), \n",
    "                             (\"Claude Haiku 3\", \"claude-3-haiku-20240307\"), \n",
    "                             (\"Gemini 2.0 Flash\", \"gemini/gemini-2.0-flash\")], \n",
    "                    label=\"Select model\"),\n",
    "        gr.Checkbox(label=\"Stream\")\n",
    "    ],\n",
    "    outputs=[gr.Markdown(label=\"Brochure:\")],\n",
    "    flagging_mode=\"never\"\n",
    ")\n",
    "view.launch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0981136-2067-43b8-b17d-83560dd609ce",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
